from torch.utils.data import DataLoader, Dataset
from config import Config
import os
from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from transformers import AutoTokenizer
import torch
import pandas as pd


class DataGenerator(Dataset):
    """Dataset of tokenized (input, target) text pairs with a float score label.

    Expects ``df`` to contain string columns ``'input'`` and ``'target'`` plus
    a numeric ``'score'`` column (regression target). Every row is tokenized
    eagerly at construction time and kept in memory.
    """

    # Both input and target sequences are padded/truncated to this many tokens.
    MAX_LEN = 20

    def __init__(self, df, config):
        """Build the dataset.

        Args:
            df: pandas DataFrame with 'input', 'target' and 'score' columns.
            config: dict-like; must provide 'pretrain_model_path' for the
                HuggingFace tokenizer.
        """
        self.inputs = df['input'].values.astype(str)
        self.targets = df['target'].values.astype(str)
        self.label = df['score'].values
        self.tokenizer = AutoTokenizer.from_pretrained(config['pretrain_model_path'])
        self.data = []
        self.load()

    def load(self):
        """Encode all rows into (input_ids, target_ids, score) tensor triples."""
        for text_in, text_tgt, score in zip(self.inputs, self.targets, self.label):
            ids_in = self.tokenizer.encode(
                text_in, max_length=self.MAX_LEN, padding='max_length', truncation=True)
            ids_tgt = self.tokenizer.encode(
                text_tgt, max_length=self.MAX_LEN, padding='max_length', truncation=True)
            # Score is a float regression target, hence FloatTensor (not LongTensor).
            self.data.append([
                torch.LongTensor(ids_in),
                torch.LongTensor(ids_tgt),
                torch.FloatTensor([score]),
            ])

    def __len__(self):
        # Count the materialized samples (same as len(self.inputs), but kept
        # consistent with what __getitem__ actually indexes).
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


def load_data(data_path, config, shuffle=True):
    """Read a CSV from ``data_path`` and return a batched DataLoader over it.

    Args:
        data_path: path to a CSV with 'input', 'target' and 'score' columns.
        config: dict-like with 'pretrain_model_path' and 'batch_size'.
        shuffle: whether the DataLoader shuffles samples each epoch.
    """
    frame = pd.read_csv(data_path)
    dataset = DataGenerator(frame, config)
    return DataLoader(dataset, batch_size=config['batch_size'], shuffle=shuffle)


if __name__ == '__main__':
    # Data preprocessing: attach the context-code titles to each training row.
    df = pd.read_csv('../usp_data/train.csv')
    df_title = pd.read_csv('../usp_data/titles.csv')
    df = df.merge(df_title, how='left', left_on='context', right_on='code')
    df = df[['id', 'anchor', 'target', 'context', 'score', 'title']]
    tokenizer = AutoTokenizer.from_pretrained(Config['pretrain_model_path'])
    # Model input = anchor phrase + [SEP] + lower-cased context title.
    df['input'] = df['anchor'] + tokenizer.sep_token + df['title'].apply(str.lower)
    train_df, valid_df = train_test_split(df, test_size=0.1, random_state=42)
    # No separate held-out test set: reuse the validation split as "test".
    test_df = valid_df
    # makedirs(exist_ok=True) avoids the check-then-create race of isdir+mkdir.
    os.makedirs('new_data', exist_ok=True)
    print(len(train_df), len(valid_df), len(test_df))
    # BUG FIX: the full dataframe was previously written to train.csv, which
    # leaked every validation row into the training file. Write the split.
    train_df.to_csv('new_data/train.csv', index=False)
    valid_df.to_csv('new_data/valid.csv', index=False)
    test_df.to_csv('new_data/test.csv', index=False)
    # NOTE(review): a 5-fold StratifiedGroupKFold split (grouped by 'anchor')
    # was previously prototyped here; re-add it if cross-validation is needed.
